/**
 * Copyright 2023-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "trans_tensor.h"

#include "infra/math/fp16_t.h"
#include "framework/infra/log/log.h"

#if defined(ARM_NEON_32)
#include <arm_neon.h>
#endif

using namespace std;
using namespace ge;
namespace hiai {
#ifdef ARM_NEON
/*
 * HalfToFloat(inPtr, outPtr): AArch64 inline-assembly kernel that converts
 * 16 consecutive fp16 values starting at address inPtr into 16 fp32 values
 * stored starting at address outPtr (4 x fcvtl on 4-lane vectors).
 *
 * inPtr / outPtr are integer addresses. outPtr must be an lvalue because it is
 * bound to a "+r" operand. The asm walks private copies of both addresses in
 * x2/x3, so the caller's variables are NOT advanced. The caller must step them
 * by 32 input bytes / 64 output bytes per invocation.
 *
 * The "memory" clobber is required: the asm loads and stores through registers
 * the compiler cannot see. Without it, surrounding accesses to the same buffers
 * could be reordered or eliminated.
 */
#define HalfToFloat(inPtr, outPtr) \
    do { \
        __asm__ __volatile__( \
            "mov x2, %[inPtr]\n" \
            "mov x3, %[outPtr]\n" \
            "ld1 {v9.4h}, [x2], #8\n" \
            "fcvtl v10.4s, v9.4h\n" \
            "ld1 {v9.4h}, [x2], #8\n" \
            "st1 {v10.4s}, [x3], #16\n" \
            "fcvtl v10.4s, v9.4h\n" \
            "ld1 {v9.4h}, [x2], #8\n" \
            "st1 {v10.4s}, [x3], #16\n" \
            "fcvtl v10.4s, v9.4h\n" \
            "ld1 {v9.4h}, [x2], #8\n" \
            "st1 {v10.4s}, [x3], #16\n" \
            "fcvtl v10.4s, v9.4h\n" \
            "st1 {v10.4s}, [x3], #16\n" \
            : [outPtr] "+r"(outPtr) \
            : [inPtr] "r"(inPtr) \
            : "x2", "x3", "v9", "v10", "memory"); \
    } while (0)
#endif

Status TransTensorHALFToFloat(const ccTensor_t& xDesc, const void* x, const ccTensor_t& yDesc, void* y)
{
    uint32_t dataCnt = xDesc.dataSize / sizeof(fp16_t);
    if (yDesc.dataSize < dataCnt * sizeof(float)) {
        FMK_LOGE("outputDataSize:%u not enough!", yDesc.dataSize);
        return FAILED;
    }

#if defined(ARM_NEON)
    uint32_t loopTime = dataCnt / 16;
    uint32_t lastTime = dataCnt % 16;
    uint64_t inPtr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(x)) + lastTime * 2;
    uint64_t outPtr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(y)) + lastTime * 4;

    for (uint32_t i = 0; i < loopTime; i++) {
        HalfToFloat(inPtr, outPtr);
        inPtr = inPtr + 32;
        outPtr = outPtr + 64;
    }
    for (uint32_t i = 0; i < lastTime; i++) {
        fp16_t fp16Data = static_cast<const uint16_t*>(x)[i];
        float fp32Data = fp16Data;
        static_cast<float*>(y)[i] = fp32Data;
    }
#elif defined(ARM_NEON_32)
    uint32_t loopTime = dataCnt / 4;
    uint32_t lastTime = dataCnt % 4;
    uint64_t inPtr = (uint64_t)((uintptr_t)x) + lastTime * 2;
    uint64_t outPtr = (uint64_t)((uintptr_t)y) + lastTime * 4;

    float16x4_t tmpfp16;
    float32x4_t tmpfp32;
    for (uint32_t i = 0; i < loopTime; i++) {
        tmpfp16 = vld1_f16(reinterpret_cast<__fp16*>(inPtr));
        tmpfp32 = vcvt_f32_f16(tmpfp16);
        vst1q_f32(reinterpret_cast<float32_t*>(outPtr), tmpfp32);
        inPtr = inPtr + 8;
        outPtr = outPtr + 16;
    }
    for (uint32_t i = 0; i < lastTime; i++) {
        fp16_t fp16Data = static_cast<const uint16_t*>(x)[i];
        float fp32Data = fp16Data;
        static_cast<float*>(y)[i] = fp32Data;
    }
#else
    for (uint32_t i = 0; i < dataCnt; i++) {
        fp16_t fp16Data = static_cast<const uint16_t*>(x)[i];
        float fp32Data = fp16Data;
        static_cast<float*>(y)[i] = fp32Data;
    }
#endif
    return SUCCESS;
}
} // namespace hiai
